/*
 * Copyright (c) 2016, Roman Meyta <theshrodingerscat@gmail.com>
 * Copyright (c) 2020-2021 https://gitee.com/fsfzp888
 * All rights reserved
 */

#include <dshow.h>
#include <sstream>

#include "ImageFormats.h"
#include "Logger.h"
#include "SampleGrabber.h"
#include "VideoCapture.h"
#include "VideoDevice.h"

using namespace std;

/**
 * @brief Find the first pin of a filter that has the requested direction
 *
 * @param pFilter filter whose pins are enumerated
 * @param PinDir direction the returned pin must have
 * @param ppPin receives the matching pin (the caller must Release() it); set to nullptr on failure
 * @return S_OK when a pin was found, E_POINTER for a null argument, the EnumPins error code when
 *         enumeration cannot start, E_FAIL when no pin has the requested direction
 */
HRESULT getPin(IBaseFilter *pFilter, PIN_DIRECTION PinDir, IPin **ppPin)
{
    if (!ppPin)
    {
        return E_POINTER;
    }
    *ppPin = nullptr;
    if (!pFilter)
    {
        return E_POINTER;
    }

    IEnumPins *pEnum = nullptr;
    HRESULT hr       = pFilter->EnumPins(&pEnum);
    if (FAILED(hr))
    {
        return hr;
    }
    if (!pEnum)
    {
        return E_FAIL;
    }

    pEnum->Reset();
    IPin *pPin = nullptr;
    while (pEnum->Next(1, &pPin, nullptr) == S_OK)
    {
        PIN_DIRECTION thisPinDir;
        // only trust thisPinDir when the query succeeded, otherwise it is uninitialised
        if (SUCCEEDED(pPin->QueryDirection(&thisPinDir)) && thisPinDir == PinDir)
        {
            pEnum->Release();
            *ppPin = pPin; // ownership of the reference moves to the caller
            return S_OK;
        }
        pPin->Release();
        pPin = nullptr;
    }
    pEnum->Release();
    return E_FAIL;
}

/**
 * @brief Initialise COM, build the filter graph, enumerate the devices and run the graph
 *
 * @param callback handler passed to every device through setCallback()
 * @param still_callback handler passed to every device through setStillCallback()
 */
VideoCapture::VideoCapture(VideoCaptureCallback callback, VideoCaptureCallback still_callback)
    : m_graph(nullptr), m_capture(nullptr), m_control(nullptr), m_readyForCapture(false), m_activeDeviceNum(0), m_devices()
{
    CoInitialize(nullptr);

    // initializeVideo() and runControl() dereference m_graph / m_capture / m_control,
    // so they must only run when the graph objects were actually created
    if (initializeGraph())
    {
        initializeVideo();
        runControl();
    }

    for (auto &device : m_devices)
    {
        device->setCallback(callback);
        device->setStillCallback(still_callback);
    }
}

/**
 * @brief Stop every device and the graph, detach all filters and release the COM objects
 */
VideoCapture::~VideoCapture()
{
    for (auto &device : m_devices)
    {
        device->stop();
    }

    // stop the graph before touching its filters: pins of an active filter cannot be
    // disconnected reliably and RemoveFilter expects a stopped graph.
    // m_control stays null when graph initialisation failed.
    if (m_control)
    {
        stopControl();
    }

    if (m_graph)
    {
        for (auto &device : m_devices)
        {
            disconnectFilters(device.get());
        }
    }
    m_devices.clear();

    if (m_control)
    {
        m_control->Release();
        m_control = nullptr;
    }

    if (m_graph)
    {
        m_graph->Release();
        m_graph = nullptr;
    }

    // NOTE(review): the release of m_capture is disabled below, so the capture graph
    // builder created in initializeGraph() is never released - confirm why it was disabled.
    // if (m_capture) {
    //    m_capture->Release();
    //    m_capture = nullptr;
    //}
}

std::vector<std::wstring> VideoCapture::getDevicesNames() const
{
    vector<wstring> names;
    for (auto &device : m_devices)
    {
        names.push_back(device->getFriendlyName());
    }
    return names;
}

std::vector<std::string> VideoCapture::getActiveDeviceResolutions() const
{
    vector<string> resolutions;
    if (m_activeDeviceNum >= m_devices.size())
    {
        return resolutions;
    }

    auto propertiesList = m_devices[m_activeDeviceNum]->getPropertiesList();
    for (auto &properties : propertiesList)
    {
        string formatName = getImageFormatName(properties.pixelFormat);
        stringstream stream;
        stream << properties.width << "x" << properties.height << "@" << formatName;
        string resolution;
        stream >> resolution;
        resolutions.push_back(resolution);
    }
    return resolutions;
}

/**
 * @brief Make another device the active one
 *
 * @param deviceNum index into the device list
 * @return false when the index is out of range or the currently active device could not be
 *         stopped, true once the active index was changed
 */
bool VideoCapture::changeActiveDevice(unsigned deviceNum)
{
    // validate first so that an invalid request does not stop the running capture
    if (deviceNum >= m_devices.size())
    {
        return false;
    }
    if (!stopCapture())
    {
        return false;
    }
    m_activeDeviceNum = deviceNum;
    return true;
}

/**
 * @brief Switch the active device to another entry of its properties list
 *
 * The capture and the graph are stopped while the properties are applied and the graph is
 * run again afterwards.
 *
 * @param resolutionNum index into the active device's properties list
 * @return true when the properties were applied and the graph is running again
 */
bool VideoCapture::changeActiveDeviceResolution(unsigned resolutionNum)
{
    if (m_activeDeviceNum >= m_devices.size())
    {
        return false;
    }

    // validate the request before stopping anything, so a bad index leaves the graph untouched
    auto propertiesList = m_devices[m_activeDeviceNum]->getPropertiesList();
    if (resolutionNum >= propertiesList.size())
    {
        return false;
    }

    stopCapture();
    if (!stopControl())
    {
        return false;
    }

    if (!m_devices[m_activeDeviceNum]->setCurrentProperties(propertiesList[resolutionNum]))
    {
        // the properties were rejected: run the graph again instead of leaving it stopped
        runControl();
        return false;
    }

    return runControl();
}

/**
 * @brief Start the active device
 *
 * @return result of the device's start(), false when the active device index is out of range
 */
bool VideoCapture::startCapture()
{
    if (m_activeDeviceNum < m_devices.size())
    {
        const auto &activeDevice = m_devices[m_activeDeviceNum];
        return activeDevice->start();
    }
    return false;
}

/**
 * @brief Stop the active device
 *
 * @return result of the device's stop(), false when the active device index is out of range
 */
bool VideoCapture::stopCapture()
{
    if (m_activeDeviceNum < m_devices.size())
    {
        const auto &activeDevice = m_devices[m_activeDeviceNum];
        return activeDevice->stop();
    }
    return false;
}

bool VideoCapture::runControl()
{
    HRESULT hr = m_control->Run();
    if (hr < 0)
    {
        return false;
    }
    m_readyForCapture = true;

    for (auto &device : m_devices)
    {
        device->stop();
        // IAMVideoControl* pAMVidControl = nullptr;

        // hr = device->m_sourceFilter->QueryInterface(IID_IAMVideoControl, (void**)&pAMVidControl);

        // if (SUCCEEDED(hr))
        //{
        //    // Find the still pin.
        //    IPin* pPin = nullptr;

        //    // pBuild is an ICaptureGraphBuilder2 pointer.

        //    hr = m_capture->FindPin(
        //        device->m_sourceFilter,                  // Filter.
        //        PINDIR_OUTPUT,         // Look for an output pin.
        //        &PIN_CATEGORY_STILL,   // Pin category.
        //        nullptr,                  // Media type (don't care).
        //        FALSE,                 // Pin must be unconnected.
        //        0,                     // Get the 0'th pin.
        //        &pPin                  // Receives a pointer to thepin.
        //    );

        //    if (SUCCEEDED(hr))
        //    {
        //        hr = pAMVidControl->SetMode(pPin, VideoControlFlag_Trigger);
        //        pPin->Release();
        //    }
        //    pAMVidControl->Release();
        //}
    }
    return true;
}

bool VideoCapture::stopControl()
{
    for (auto &device : m_devices)
    {
        device->stop();
    }
    m_readyForCapture = false;
    HRESULT hr        = m_control->Stop();
    if (hr < 0)
    {
        return false;
    }
    return true;
}

bool VideoCapture::initializeGraph()
{
    HRESULT hr = S_FALSE;

    // create the FilterGraph
    hr = CoCreateInstance(CLSID_FilterGraph, nullptr, CLSCTX_INPROC, IID_IFilterGraph2, reinterpret_cast<void **>(&m_graph));
    if (hr < 0 || !m_graph)
    {
        return false;
    }

    // create the CaptureGraphBuilder
    hr = CoCreateInstance(CLSID_CaptureGraphBuilder2, nullptr, CLSCTX_INPROC, IID_ICaptureGraphBuilder2, reinterpret_cast<void **>(&m_capture));
    if (hr < 0 || !m_capture)
    {
        return false;
    }

    // get the controller for the graph
    hr = m_graph->QueryInterface(IID_IMediaControl, reinterpret_cast<void **>(&m_control));
    if (hr < 0 || !m_control)
    {
        return false;
    }
    m_capture->SetFiltergraph(m_graph);
    return true;
}

/**
 * @brief Build a unique name by appending "id<id>" to a base name
 *
 * @param name base name, copied unchanged to the front of the result
 * @param id number formatted after the "id" marker
 * @return wstring name followed by "id" and the formatted id
 */
wstring addIdToName(const std::wstring &name, int id)
{
    // format the suffix as narrow text first, then widen it character by character
    std::stringstream suffixBuilder;
    suffixBuilder << "id" << id;
    const std::string narrowSuffix = suffixBuilder.str();

    std::wstring uniqueName(name);
    uniqueName.append(std::wstring(narrowSuffix.begin(), narrowSuffix.end()));
    return uniqueName;
}

bool VideoCapture::initializeVideo()
{
    HRESULT hr = S_FALSE;
    VARIANT name;
    wstring filterName;

    ICreateDevEnum *devEnum   = nullptr;
    IEnumMoniker *enumMoniker = nullptr;
    IMoniker *moniker         = nullptr;
    IPropertyBag *pbag        = nullptr;

    // create an enumerator for video input devices
    hr = CoCreateInstance(CLSID_SystemDeviceEnum, nullptr, CLSCTX_INPROC_SERVER, IID_ICreateDevEnum, reinterpret_cast<void **>(&devEnum));
    if (hr < 0 || !devEnum)
    {
        return false;
    }

    hr = devEnum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enumMoniker, 0);
    if (hr < 0 || !enumMoniker)
    {
        return false;
    }

    int devNum = 0;
    while (enumMoniker->Next(1, &moniker, 0) == S_OK)
    {
        ++devNum;
        hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast<void **>(&pbag));
        if (hr < 0)
        {
            moniker->Release();
            continue;
        }

        VariantInit(&name);

        hr = pbag->Read(L"Description", &name, 0);
        if (hr < 0)
        {
            hr = pbag->Read(L"FriendlyName", &name, 0);
            if (hr < 0)
            {
                moniker->Release();
                continue;
            }
        }

        shared_ptr<VideoDevice> device(new VideoDevice);
        device->m_id = devNum;
        std::wstring wname(name.bstrVal, SysStringLen(name.bstrVal));
        device->m_friendlyName = device->m_filterName = wname;
        device->m_filterName                          = addIdToName(device->m_filterName, device->m_id);

        // add a filter for the device
        hr = m_graph->AddSourceFilterForMoniker(moniker, nullptr, device->m_filterName.c_str(), &device->m_sourceFilter);
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to add source filter for moniker of current graph for device %s", device->m_filterName.c_str());
            continue;
        }

        // create a samplegrabber filter for the device
        hr = CoCreateInstance(GUIDHolder::CLSID_SampleGrabber, nullptr, CLSCTX_INPROC_SERVER, IID_IBaseFilter,
                              reinterpret_cast<void **>(&device->m_sampleGrabberFilter));
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to create CLSID_SampleGrabber COM instance for device %s", device->m_filterName.c_str());
            continue;
        }

        // set mediatype on the samplegrabber
        hr = device->m_sampleGrabberFilter->QueryInterface(GUIDHolder::IID_ISampleGrabber, reinterpret_cast<void **>(&device->m_sampleGrabber));
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to query sample grabber interface IID_ISampleGrabber for device %s", device->m_filterName.c_str());
            continue;
        }

        hr = CoCreateInstance(GUIDHolder::CLSID_SampleGrabber, nullptr, CLSCTX_INPROC_SERVER, IID_IBaseFilter,
                              reinterpret_cast<void **>(&device->m_sampleStillFilter));
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to create CLSID_SampleGrabber COM instance for device %s", device->m_filterName.c_str());
            continue;
        }

        hr = device->m_sampleStillFilter->QueryInterface(GUIDHolder::IID_ISampleGrabber, reinterpret_cast<void **>(&device->m_sampleStill));
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to query sample still interface IID_ISampleGrabber for device %s", device->m_filterName.c_str());
            continue;
        }

        // set device capabilities
        updateDeviceCapabilities(device.get());

        filterName = L"SG" + device->m_filterName;
        m_graph->AddFilter(device->m_sampleGrabberFilter, filterName.c_str());

        // set the media type
        AM_MEDIA_TYPE mt;
        memset(&mt, 0, sizeof(AM_MEDIA_TYPE));

        mt.majortype = MEDIATYPE_Video;
        mt.subtype   = device->getCurrentProperties().pixelFormat;

        hr = device->m_sampleGrabber->SetMediaType(&mt);
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to set media type of sample grabber for device %s", device->m_filterName.c_str());
            continue;
        }

        // add the callback to the samplegrabber
        hr = device->m_sampleGrabber->SetCallback(device->m_callbackHandler, 1);
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to set callback function of sample grabber for device %s", device->m_filterName.c_str());
            continue;
        }

        AM_MEDIA_TYPE mt2;
        memset(&mt2, 0, sizeof(AM_MEDIA_TYPE));

        mt2.majortype = MEDIATYPE_Video;
        mt2.subtype   = device->getCurrentProperties().pixelFormat;
        hr            = device->m_sampleStill->SetMediaType(&mt);
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to set media type of sample still for device %s", device->m_filterName.c_str());
            continue;
        }
        device->m_sampleStill->SetOneShot(false);
        device->m_sampleStill->SetBufferSamples(true);
        hr = device->m_sampleStill->SetCallback(device->m_stillCallbackHandler, 1);
        if (hr != S_OK)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to set callback of sample still for device %s", device->m_filterName.c_str());
            continue;
        }

        // set the null renderer
        hr = CoCreateInstance(GUIDHolder::CLSID_NullRenderer, nullptr, CLSCTX_INPROC_SERVER, IID_IBaseFilter,
                              reinterpret_cast<void **>(&device->m_nullRenderer));
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to create CLSID_NullRenderer COM instance for device %s", device->m_filterName.c_str());
            continue;
        }

        filterName = L"NR" + device->m_filterName;
        m_graph->AddFilter(device->m_nullRenderer, filterName.c_str());

        hr = m_capture->RenderStream(&PIN_CATEGORY_CAPTURE, &MEDIATYPE_Video, device->m_sourceFilter, device->m_sampleGrabberFilter,
                                     device->m_nullRenderer);
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to render videw capture and preview stream for device %s", device->m_filterName.c_str());
            continue;
        }

        filterName = L"SFGN" + device->m_filterName;
        m_graph->AddFilter(device->m_sampleStillFilter, filterName.c_str());

        hr = CoCreateInstance(GUIDHolder::CLSID_NullRenderer, nullptr, CLSCTX_INPROC_SERVER, IID_IBaseFilter,
                              reinterpret_cast<void **>(&device->m_nullRenderer2));
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            LOG_ERROR("Fail to create CLSID_NullRenderer for device %s", device->m_filterName.c_str());
            continue;
        }
        filterName = L"NR2" + device->m_filterName;
        m_graph->AddFilter(device->m_nullRenderer2, filterName.c_str());

        hr = m_capture->RenderStream(&PIN_CATEGORY_STILL, &MEDIATYPE_Video, device->m_sourceFilter, device->m_sampleStillFilter,
                                     device->m_nullRenderer2);
        //if (hr == E_INVALIDARG)
        if (hr < 0)
        {
            // some deivce do not support still pin capture, just remove still filter in this deivce class
            LOG_INFO("Current device %s do not support still pin capture, hardware trigger photograph functionality would be disabled", device->m_filterName.c_str());
            IPin *pPin = nullptr;
            hr = getPin(device->m_sampleStillFilter, PINDIR_INPUT, &pPin);
            if (SUCCEEDED(hr))
            {
                m_graph->Disconnect(pPin);
                pPin->Release();
                pPin = nullptr;
            }
            hr = getPin(device->m_sampleStillFilter, PINDIR_OUTPUT, &pPin);
            if (SUCCEEDED(hr))
            {
                m_graph->Disconnect(pPin);
                pPin->Release();
                pPin = nullptr;
            }
            hr = getPin(device->m_nullRenderer2, PINDIR_INPUT, &pPin);
            if (SUCCEEDED(hr))
            {
                m_graph->Disconnect(pPin);
                pPin->Release();
                pPin = nullptr;
            }
            m_graph->RemoveFilter(device->m_nullRenderer2);
            m_graph->RemoveFilter(device->m_sampleStillFilter);
            device->RemoveStillRelatedFilter();
        }

        // if the stream is started, start capturing immediatly
        LONGLONG start = 0, stop = MAXLONGLONG;
        hr = m_capture->ControlStream(&PIN_CATEGORY_CAPTURE, &MEDIATYPE_Video, device->m_sourceFilter, &start, &stop, 1, 2);
        if (hr < 0)
        {
            pbag->Release();
            moniker->Release();
            continue;
        }

        // reference the graph
        device->m_graph = m_graph;
        m_devices.push_back(device);

        VariantClear(&name);
        pbag->Release();
        moniker->Release();
    }
    enumMoniker->Release();
    devEnum->Release();
    return true;
}

/**
 * @brief Disconnect the pins of a device's filters and remove the filters from the graph
 *
 * Filters that the device does not have (null pointers) are skipped.
 *
 * @param device device whose filters are detached; nothing happens for nullptr
 */
void VideoCapture::disconnectFilters(VideoDevice *device)
{
    if (!device || !m_graph)
    {
        return;
    }

    // disconnect the first pin of `filter` that has the given direction, if there is one
    auto disconnectPin = [this](IBaseFilter *filter, PIN_DIRECTION direction) {
        if (!filter)
        {
            return;
        }
        IPin *pPin = nullptr;
        if (SUCCEEDED(getPin(filter, direction, &pPin)) && pPin)
        {
            m_graph->Disconnect(pPin);
            pPin->Release();
        }
    };

    // remove `filter` from the graph when the device actually has it
    auto removeFilter = [this](IBaseFilter *filter) {
        if (filter)
        {
            m_graph->RemoveFilter(filter);
        }
    };

    disconnectPin(device->m_sourceFilter, PINDIR_OUTPUT);
    disconnectPin(device->m_sampleGrabberFilter, PINDIR_INPUT);
    disconnectPin(device->m_sampleGrabberFilter, PINDIR_OUTPUT);
    disconnectPin(device->m_sampleStillFilter, PINDIR_INPUT);
    disconnectPin(device->m_sampleStillFilter, PINDIR_OUTPUT);
    disconnectPin(device->m_nullRenderer, PINDIR_INPUT);
    disconnectPin(device->m_nullRenderer2, PINDIR_INPUT);

    // remove from the renderers back to the source
    removeFilter(device->m_nullRenderer);
    removeFilter(device->m_nullRenderer2);
    removeFilter(device->m_sampleStillFilter);
    removeFilter(device->m_sampleGrabberFilter);
    removeFilter(device->m_sourceFilter);
}

/**
 * @brief Check if media type is known
 *
 * A media type is accepted when it is video with a VIDEOINFOHEADER format block, has a
 * positive width and height and a subtype that isKnownImageFormat() recognises.
 *
 * @param type media type to inspect; nullptr is rejected
 * @return true when the media type is accepted, false otherwise
 */
bool checkMediaType(AM_MEDIA_TYPE *type)
{
    if (!type)
    {
        return false;
    }

    if (type->majortype != MEDIATYPE_Video || type->formattype != FORMAT_VideoInfo)
    {
        return false;
    }

    // the format block must exist and be large enough before it is read as a VIDEOINFOHEADER
    if (!type->pbFormat || type->cbFormat < sizeof(VIDEOINFOHEADER))
    {
        return false;
    }

    const VIDEOINFOHEADER *pvi = reinterpret_cast<const VIDEOINFOHEADER *>(type->pbFormat);
    if (pvi->bmiHeader.biWidth <= 0 || pvi->bmiHeader.biHeight <= 0)
    {
        return false;
    }

    return isKnownImageFormat(type->subtype);
}

/**
 * @brief Query the capture formats of a device and append them to its properties list
 *
 * The stream config interface of the capture pin is stored in device->m_config. Every format
 * accepted by checkMediaType() is appended to device->m_propertiesList together with the flip
 * flags reported by IAMVideoControl, and the first list entry becomes the current properties.
 * No format is appended when the flip flags cannot be queried.
 *
 * NOTE(review): the list is appended to, not cleared - confirm whether this function is ever
 * called more than once for the same device.
 *
 * @param device device to update; nullptr is rejected
 * @return false when the stream config interface or its capability count is unavailable,
 *         true otherwise
 */
bool VideoCapture::updateDeviceCapabilities(VideoDevice *device)
{
    if (!device || !m_capture)
    {
        return false;
    }

    IAMStreamConfig *pConfig = nullptr;
    HRESULT hr = m_capture->FindInterface(&PIN_CATEGORY_CAPTURE, &MEDIATYPE_Video, device->m_sourceFilter, IID_IAMStreamConfig,
                                          reinterpret_cast<void **>(&pConfig));
    if (FAILED(hr) || !pConfig)
    {
        return false;
    }

    int iCount = 0;
    int iSize  = 0;
    hr         = pConfig->GetNumberOfCapabilities(&iCount, &iSize);
    if (FAILED(hr))
    {
        pConfig->Release();
        return false;
    }

    if (device->m_config)
    {
        device->m_config->Release();
    }
    device->m_config = pConfig;

    // GetStreamCaps writes iSize bytes, so the buffer must never be smaller than that
    size_t capsSize = sizeof(VIDEO_STREAM_CONFIG_CAPS);
    if (iSize > 0 && static_cast<size_t>(iSize) > capsSize)
    {
        capsSize = static_cast<size_t>(iSize);
    }
    vector<BYTE> capsBuffer(capsSize);

    // frees a media type returned by GetStreamCaps; the format block is kept when a shallow
    // copy of the structure still points at it
    auto freeMediaType = [](AM_MEDIA_TYPE *mediaType, bool releaseFormat) {
        if (releaseFormat)
        {
            if (mediaType->pbFormat)
            {
                CoTaskMemFree(mediaType->pbFormat);
            }
            if (mediaType->pUnk)
            {
                mediaType->pUnk->Release();
            }
        }
        CoTaskMemFree(mediaType);
    };

    // the flip mode does not depend on the enumerated format, so it is queried once
    long supportedModes = 0;
    long mode           = 0;
    bool haveMode       = false;

    IAMVideoControl *pVideoControl = nullptr;
    hr = m_capture->FindInterface(&PIN_CATEGORY_CAPTURE, &MEDIATYPE_Video, device->m_sourceFilter, IID_IAMVideoControl,
                                  reinterpret_cast<void **>(&pVideoControl));
    if (SUCCEEDED(hr) && pVideoControl)
    {
        IPin *pPin = nullptr;
        if (SUCCEEDED(getPin(device->m_sourceFilter, PINDIR_OUTPUT, &pPin)) && pPin)
        {
            haveMode = SUCCEEDED(pVideoControl->GetCaps(pPin, &supportedModes)) && SUCCEEDED(pVideoControl->GetMode(pPin, &mode));
            pPin->Release();
        }
        pVideoControl->Release();
    }

    for (int iIndex = 0; haveMode && iIndex < iCount; ++iIndex)
    {
        AM_MEDIA_TYPE *pmt = nullptr;
        hr                 = pConfig->GetStreamCaps(iIndex, &pmt, capsBuffer.data());
        if (FAILED(hr) || !pmt)
        {
            continue;
        }

        if (!checkMediaType(pmt))
        {
            freeMediaType(pmt, true);
            continue;
        }

        VideoDevice::Properties properties;
        VIDEOINFOHEADER *pvi   = reinterpret_cast<VIDEOINFOHEADER *>(pmt->pbFormat);
        properties.mediaType   = *pmt;
        properties.width       = pvi->bmiHeader.biWidth;
        properties.height      = pvi->bmiHeader.biHeight;
        properties.pixelFormat = pmt->subtype;

        properties.isFlippedHorizontal = mode & VideoControlFlag_FlipHorizontal;
        properties.isFlippedVertical   = (mode & VideoControlFlag_FlipVertical) >> 1;

        device->m_propertiesList.push_back(properties);

        // properties.mediaType was copied from *pmt and may still point at the format block,
        // so only the outer structure is freed here
        freeMediaType(pmt, false);
    }

    if (!device->m_propertiesList.empty())
    {
        device->m_currentProperties = *device->m_propertiesList.begin();
    }
    return true;
}
